Wasserstein GAN with Gradient Penalty (WGAN-GP)

Goals

In this notebook, you're going to build a Wasserstein GAN with Gradient Penalty (WGAN-GP) that solves some of the stability issues with the GANs that you have been using up until this point. Specifically, you'll use a special kind of loss function known as the W-loss, where W stands for Wasserstein, and gradient penalties to prevent mode collapse.

Fun Fact: The Wasserstein distance is named after Leonid Vaseršteĭn, a mathematician at Penn State. You'll see it abbreviated to W (e.g. WGAN, W-loss, W-distance).
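For reference, a standard way to write the WGAN-GP objective you'll implement (this formulation isn't stated explicitly in the notebook, but it matches the losses you'll build below) is:

```latex
\min_{g}\;\max_{c}\;
\mathbb{E}_{x \sim p_{\text{real}}}\!\left[c(x)\right]
- \mathbb{E}_{z}\!\left[c(g(z))\right]
- \lambda\, \mathbb{E}_{\hat{x}}\!\left[\left(\lVert \nabla_{\hat{x}}\, c(\hat{x}) \rVert_2 - 1\right)^2\right],
\qquad \hat{x} = \epsilon x + (1-\epsilon)\, g(z)
```

The critic maximizes the gap between its mean scores on real and fake images, subject to a penalty that pushes gradient norms at interpolated points toward 1; the generator minimizes by raising the critic's scores on its fakes.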

Learning Objectives

  1. Get hands-on experience building a more stable GAN: Wasserstein GAN with Gradient Penalty (WGAN-GP).
  2. Train the more advanced WGAN-GP model.

Generator and Critic

You will begin by importing some useful packages, defining visualization functions, building the generator, and building the critic. Since the changes for WGAN-GP are made to the loss function during training, you can simply reuse your previous GAN code for the generator and critic classes. Remember that in WGAN-GP, you no longer use a discriminator that classifies fake and real as 0 and 1, but rather a critic that scores images with real numbers.

Packages and Visualizations

In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0) # Set for testing purposes, please do not change!

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Function for visualizing images: Given a tensor of images, number of images, and
    size per image, plots and prints the images in a uniform grid.
    '''
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

def make_grad_hook():
    '''
    Function to keep track of gradients for visualization purposes:
    model.apply(grad_hook) fills the grads list with the weight gradients
    of the model's convolutional and transposed-convolutional layers.
    '''
    grads = []
    def grad_hook(m):
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
            grads.append(m.weight.grad)
    return grads, grad_hook
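As a quick sketch of how the hook is meant to be used (this toy model is not the assignment's critic, just an illustration): run a backward pass first, then call model.apply(grad_hook) to collect the conv layers' weight gradients.

```python
import torch
from torch import nn

# Re-declare the hook factory so this sketch is self-contained.
def make_grad_hook():
    grads = []
    def grad_hook(m):
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
            grads.append(m.weight.grad)
    return grads, grad_hook

# Toy model with a single conv layer.
model = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3))
model(torch.randn(2, 1, 8, 8)).sum().backward()

grads, grad_hook = make_grad_hook()
model.apply(grad_hook)  # visits every submodule, collecting conv weight grads
print(len(grads), grads[0].shape)  # 1 torch.Size([4, 1, 3, 3])
```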

Generator and Noise

In [2]:
class Generator(nn.Module):
    '''
    Generator Class
    Values:
        z_dim: the dimension of the noise vector, a scalar
        im_chan: the number of channels in the images, fitted for the dataset used, a scalar
              (MNIST is black-and-white, so 1 channel is your default)
        hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, z_dim=10, im_chan=1, hidden_dim=64):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        # Build the neural network
        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to a generator block of DCGAN;
        a transposed convolution, a batchnorm (except in the final layer), and an activation.
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: a boolean, true if it is the final layer and false otherwise 
                      (affects activation and batchnorm)
        '''
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh(),
            )

    def forward(self, noise):
        '''
        Function for completing a forward pass of the generator: Given a noise tensor,
        returns generated images.
        Parameters:
            noise: a noise tensor with dimensions (n_samples, z_dim)
        '''
        x = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(x)

def get_noise(n_samples, z_dim, device='cpu'):
    '''
    Function for creating noise vectors: Given the dimensions (n_samples, z_dim),
    creates a tensor of that shape filled with random numbers from the normal distribution.
    Parameters:
      n_samples: the number of samples to generate, a scalar
      z_dim: the dimension of the noise vector, a scalar
      device: the device type
    '''
    return torch.randn(n_samples, z_dim, device=device)

Critic

In [3]:
class Critic(nn.Module):
    '''
    Critic Class
    Values:
        im_chan: the number of channels in the images, fitted for the dataset used, a scalar
              (MNIST is black-and-white, so 1 channel is your default)
        hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, im_chan=1, hidden_dim=64):
        super(Critic, self).__init__()
        self.crit = nn.Sequential(
            self.make_crit_block(im_chan, hidden_dim),
            self.make_crit_block(hidden_dim, hidden_dim * 2),
            self.make_crit_block(hidden_dim * 2, 1, final_layer=True),
        )

    def make_crit_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to a critic block of DCGAN;
        a convolution, a batchnorm (except in the final layer), and an activation (except in the final layer).
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: a boolean, true if it is the final layer and false otherwise 
                      (affects activation and batchnorm)
        '''
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )

    def forward(self, image):
        '''
        Function for completing a forward pass of the critic: Given an image tensor, 
        returns a 1-dimension tensor of critic scores, one per image.
        Parameters:
            image: an image tensor with dimensions (n_samples, im_chan, height, width)
        '''
        crit_pred = self.crit(image)
        return crit_pred.view(len(crit_pred), -1)

Training Initializations

Now you can start putting it all together. As usual, you will start by setting the parameters:

  • n_epochs: the number of times you iterate through the entire dataset when training
  • z_dim: the dimension of the noise vector
  • display_step: how often to display/visualize the images
  • batch_size: the number of images per forward/backward pass
  • lr: the learning rate
  • beta_1, beta_2: the momentum terms
  • c_lambda: weight of the gradient penalty
  • crit_repeats: number of times to update the critic per generator update - there are more details about this in the Putting It All Together section
  • device: the device type

You will also load and transform the MNIST dataset to tensors.
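As a quick sanity check (not part of the assignment): Normalize((0.5,), (0.5,)) computes (pixel - 0.5) / 0.5, mapping ToTensor's [0, 1] range onto [-1, 1], which matches the generator's Tanh output range and explains the (image_tensor + 1) / 2 rescaling in show_tensor_images.

```python
import torch

# Normalize((0.5,), (0.5,)) maps [0, 1] pixels onto the Tanh range [-1, 1].
pixels = torch.tensor([0.0, 0.5, 1.0])
normalized = (pixels - 0.5) / 0.5
print(normalized)  # tensor([-1., 0., 1.])

# show_tensor_images inverts this before plotting:
recovered = (normalized + 1) / 2
print(torch.equal(recovered, pixels))  # True
```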

In [4]:
n_epochs = 100
z_dim = 64
display_step = 50
batch_size = 128
lr = 0.0002
beta_1 = 0.5
beta_2 = 0.999
c_lambda = 10
crit_repeats = 5
device = 'cuda'

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataloader = DataLoader(
    MNIST('.', download=False, transform=transform),
    batch_size=batch_size,
    shuffle=True)

Then, you can initialize your generator, critic, and optimizers.

In [5]:
gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr, betas=(beta_1, beta_2))
crit = Critic().to(device) 
crit_opt = torch.optim.Adam(crit.parameters(), lr=lr, betas=(beta_1, beta_2))

def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
gen = gen.apply(weights_init)
crit = crit.apply(weights_init)

Gradient Penalty

Calculating the gradient penalty can be broken into two functions: (1) compute the gradient with respect to the images and (2) compute the gradient penalty given the gradient.

You can start by getting the gradient. The gradient is computed by first creating a mixed image: the real and fake images are weighted by epsilon and (1 - epsilon) respectively, then added together. Once you have the intermediate image, you can get the critic's output on it. Finally, you compute the gradient of the critic's scores on the mixed images (output) with respect to the pixels of the mixed images (input). You will need to fill in the code to get the gradient wherever you see None. There is a test function in the next block for you to test your solution.

In [6]:
# UNQ_C1 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: get_gradient
def get_gradient(crit, real, fake, epsilon):
    '''
    Return the gradient of the critic's scores with respect to mixes of real and fake images.
    Parameters:
        crit: the critic model
        real: a batch of real images
        fake: a batch of fake images
        epsilon: a vector of the uniformly random proportions of real/fake per mixed image
    Returns:
        gradient: the gradient of the critic's scores, with respect to the mixed image
    '''
    # Mix the images together
    mixed_images = real * epsilon + fake * (1 - epsilon)

    # Calculate the critic's scores on the mixed images
    mixed_scores = crit(mixed_images)
    
    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        # Note: You need to take the gradient of outputs with respect to inputs.
        # This documentation may be useful, but it should not be necessary:
        # https://pytorch.org/docs/stable/autograd.html#torch.autograd.grad
        #### START CODE HERE ####
        inputs = mixed_images,
        outputs = mixed_scores,
        #### END CODE HERE ####
        # These other parameters relate to how the PyTorch autograd engine works
        grad_outputs=torch.ones_like(mixed_scores), 
        create_graph=True,
        retain_graph=True,
    )[0]
    return gradient
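If the autograd.grad call is unfamiliar, here is a toy example (independent of the assignment) that uses the same pattern: one scalar score per sample, differentiated with respect to the input tensor.

```python
import torch

# Each row of x is a "sample"; score holds one scalar per sample,
# analogous to the critic's score per mixed image.
x = torch.randn(4, 3, requires_grad=True)
score = (x ** 2).sum(dim=1)

# grad_outputs=ones effectively differentiates the sum of the scores,
# yielding each sample's gradient with respect to its own input row.
grad = torch.autograd.grad(
    outputs=score,
    inputs=x,
    grad_outputs=torch.ones_like(score),
)[0]
print(torch.allclose(grad, 2 * x))  # True, since d(x^2)/dx = 2x
```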
In [7]:
# UNIT TEST
# DO NOT MODIFY THIS
def test_get_gradient(image_shape):
    real = torch.randn(*image_shape, device=device) + 1
    fake = torch.randn(*image_shape, device=device) - 1
    epsilon_shape = [1 for _ in image_shape]
    epsilon_shape[0] = image_shape[0]
    epsilon = torch.rand(epsilon_shape, device=device).requires_grad_()
    gradient = get_gradient(crit, real, fake, epsilon)
    assert tuple(gradient.shape) == image_shape
    assert gradient.max() > 0
    assert gradient.min() < 0
    return gradient

gradient = test_get_gradient((256, 1, 28, 28))
print("Success!")
Success!

The second function you need to complete is to compute the gradient penalty given the gradient. First, you calculate the magnitude of each image's gradient. The magnitude of a gradient is also called the norm. Then, you calculate the penalty by squaring the distance between each magnitude and the ideal norm of 1 and taking the mean of all the squared distances.

Again, you will need to fill in the code wherever you see None. There are hints below that you can view if you need help and there is a test function in the next block for you to test your solution.

Optional hints for gradient_penalty

  1. Make sure you take the mean at the end.
  2. Note that the magnitude of each gradient has already been calculated for you.
In [8]:
# UNQ_C2 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: gradient_penalty
def gradient_penalty(gradient):
    '''
    Return the gradient penalty, given a gradient.
    Given a batch of image gradients, you calculate the magnitude of each image's gradient
    and penalize the mean quadratic distance of each magnitude to 1.
    Parameters:
        gradient: the gradient of the critic's scores, with respect to the mixed image
    Returns:
        penalty: the gradient penalty
    '''
    # Flatten the gradients so that each row captures one image
    gradient = gradient.view(len(gradient), -1)

    # Calculate the magnitude of every row
    gradient_norm = gradient.norm(2, dim=1)
    
    # Penalize the mean squared distance of the gradient norms from 1
    #### START CODE HERE ####
    penalty = torch.mean((gradient_norm - 1)**2)
    #### END CODE HERE ####
    return penalty
In [9]:
# UNIT TEST
def test_gradient_penalty(image_shape):
    bad_gradient = torch.zeros(*image_shape)
    bad_gradient_penalty = gradient_penalty(bad_gradient)
    assert torch.isclose(bad_gradient_penalty, torch.tensor(1.))

    image_size = torch.prod(torch.Tensor(image_shape[1:]))
    good_gradient = torch.ones(*image_shape) / torch.sqrt(image_size)
    good_gradient_penalty = gradient_penalty(good_gradient)
    assert torch.isclose(good_gradient_penalty, torch.tensor(0.))

    random_gradient = test_get_gradient(image_shape)
    random_gradient_penalty = gradient_penalty(random_gradient)
    assert torch.abs(random_gradient_penalty - 1) < 0.1

test_gradient_penalty((256, 1, 28, 28))
print("Success!")
Success!

Losses

Next, you need to calculate the loss for the generator and the critic.

For the generator, the loss is calculated by maximizing the critic's scores on the generator's fake images. The argument contains the critic's scores for all fake images in the batch; you will use their mean.

There are optional hints below and a test function in the next block for you to test your solution.

Optional hints for get_gen_loss

  1. This can be written in one line.
  2. This is the negative of the mean of the critic's scores.
In [10]:
# UNQ_C3 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: get_gen_loss
def get_gen_loss(crit_fake_pred):
    '''
    Return the loss of a generator given the critic's scores of the generator's fake images.
    Parameters:
        crit_fake_pred: the critic's scores of the fake images
    Returns:
        gen_loss: a scalar loss value for the current batch of the generator
    '''
    #### START CODE HERE ####
    gen_loss = -1. * torch.mean(crit_fake_pred)
    #### END CODE HERE ####
    return gen_loss
In [11]:
# UNIT TEST
assert torch.isclose(
    get_gen_loss(torch.tensor(1.)), torch.tensor(-1.0)
)

assert torch.isclose(
    get_gen_loss(torch.rand(10000)), torch.tensor(-0.5), 0.05
)

print("Success!")
Success!

For the critic, the loss is calculated by maximizing the distance between the critic's scores on the real images and its scores on the fake images, while also adding a gradient penalty. The gradient penalty is weighted by lambda. The arguments contain the scores for all the images in the batch; you will use their means.

There are hints below if you get stuck and a test function in the next block for you to test your solution.

Optional hints for get_crit_loss

  1. The higher the mean fake score, the higher the critic's loss is.
  2. What does this suggest about the mean real score?
  3. The higher the gradient penalty, the higher the critic's loss is, proportional to lambda.
In [12]:
# UNQ_C4 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: get_crit_loss
def get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda):
    '''
    Return the loss of a critic given the critic's scores for fake and real images,
    the gradient penalty, and gradient penalty weight.
    Parameters:
        crit_fake_pred: the critic's scores of the fake images
        crit_real_pred: the critic's scores of the real images
        gp: the unweighted gradient penalty
        c_lambda: the current weight of the gradient penalty 
    Returns:
        crit_loss: a scalar for the critic's loss, accounting for the relevant factors
    '''
    #### START CODE HERE ####
    crit_loss = torch.mean(crit_fake_pred) - torch.mean(crit_real_pred) + c_lambda * gp
    #### END CODE HERE ####
    return crit_loss
In [13]:
# UNIT TEST
assert torch.isclose(
    get_crit_loss(torch.tensor(1.), torch.tensor(2.), torch.tensor(3.), 0.1),
    torch.tensor(-0.7)
)
assert torch.isclose(
    get_crit_loss(torch.tensor(20.), torch.tensor(-20.), torch.tensor(2.), 10),
    torch.tensor(60.)
)

print("Success!")
Success!

Putting It All Together

Before you put everything together, there are a few things to note.

  1. Even on GPU, the training will run more slowly than previous labs because the gradient penalty requires you to compute the gradient of a gradient -- this means potentially a few minutes per epoch! For best results, run this for as long as you can while on GPU.
  2. One important difference from earlier versions is that you will update the critic multiple times for every generator update. This helps prevent the generator from overpowering the critic. Sometimes, you might see the reverse, with the generator updated more times than the critic. This depends on architectural choices (e.g. the depth and width of the network) and algorithmic choices (e.g. which loss you're using).
  3. WGAN-GP isn't necessarily meant to improve the overall performance of a GAN, but to increase stability and avoid mode collapse. In general, a WGAN will train in a much more stable way than the vanilla DCGAN from the last assignment, though it will generally run a bit slower. You should also be able to train your model for more epochs without it collapsing.

Here is a snapshot of what your WGAN-GP outputs should resemble: [image: MNIST Digits Progression]

In [ ]:
import matplotlib.pyplot as plt

cur_step = 0
generator_losses = []
critic_losses = []
for epoch in range(n_epochs):
    # Dataloader returns the batches
    for real, _ in tqdm(dataloader):
        cur_batch_size = len(real)
        real = real.to(device)

        mean_iteration_critic_loss = 0
        for _ in range(crit_repeats):
            ### Update critic ###
            crit_opt.zero_grad()
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            fake = gen(fake_noise)
            crit_fake_pred = crit(fake.detach())
            crit_real_pred = crit(real)

            epsilon = torch.rand(len(real), 1, 1, 1, device=device, requires_grad=True)
            gradient = get_gradient(crit, real, fake.detach(), epsilon)
            gp = gradient_penalty(gradient)
            crit_loss = get_crit_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)

            # Keep track of the average critic loss in this batch
            mean_iteration_critic_loss += crit_loss.item() / crit_repeats
            # Update gradients
            crit_loss.backward(retain_graph=True)
            # Update optimizer
            crit_opt.step()
        critic_losses += [mean_iteration_critic_loss]

        ### Update generator ###
        gen_opt.zero_grad()
        fake_noise_2 = get_noise(cur_batch_size, z_dim, device=device)
        fake_2 = gen(fake_noise_2)
        crit_fake_pred = crit(fake_2)
        
        gen_loss = get_gen_loss(crit_fake_pred)
        gen_loss.backward()

        # Update the weights
        gen_opt.step()

        # Keep track of the average generator loss
        generator_losses += [gen_loss.item()]

        ### Visualization code ###
        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(generator_losses[-display_step:]) / display_step
            crit_mean = sum(critic_losses[-display_step:]) / display_step
            print(f"Step {cur_step}: Generator loss: {gen_mean}, critic loss: {crit_mean}")
            show_tensor_images(fake)
            show_tensor_images(real)
            step_bins = 20
            num_examples = (len(generator_losses) // step_bins) * step_bins
            plt.plot(
                range(num_examples // step_bins), 
                torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Generator Loss"
            )
            plt.plot(
                range(num_examples // step_bins), 
                torch.Tensor(critic_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Critic Loss"
            )
            plt.legend()
            plt.show()

        cur_step += 1
Step 50: Generator loss: -0.08869026243221015, critic loss: 1.8815238857567311
Step 100: Generator loss: 1.0189525532722472, critic loss: -2.000770281553269
Step 150: Generator loss: 3.1063385939598085, critic loss: -10.454776483535763
Step 200: Generator loss: -0.03096360892057419, critic loss: -24.949981227874755
Step 250: Generator loss: -2.1902864092588423, critic loss: -54.09114030456542
Step 300: Generator loss: -1.9651418948173522, critic loss: -89.99565879821776
Step 350: Generator loss: 1.5058544285595417, critic loss: -135.3717141113281
Step 400: Generator loss: 5.135000143051148, critic loss: -184.69015820312507
Step 450: Generator loss: 7.391116046905518, critic loss: -234.38668341064454

Step 500: Generator loss: 9.434671916961669, critic loss: -291.52044934082016
Step 550: Generator loss: 10.198309473991394, critic loss: -339.4726991119385
Step 600: Generator loss: 9.59009435236454, critic loss: -348.55542071533205
Step 650: Generator loss: -14.686983901262284, critic loss: -326.99507629394526
Step 700: Generator loss: -4.207686264514923, critic loss: -348.30472375488273
Step 750: Generator loss: 4.711256952285766, critic loss: -364.4952556762695
Step 800: Generator loss: -11.032265238761902, critic loss: -333.9789097900391
Step 850: Generator loss: -4.984796676635742, critic loss: -315.2135946044922
Step 900: Generator loss: -12.270385609865189, critic loss: -260.5971811218261

Step 950: Generator loss: -26.233554520010948, critic loss: -157.75049143218993
Step 1000: Generator loss: -99.86857343673707, critic loss: -175.97676026916503
Step 1050: Generator loss: -79.90346534729004, critic loss: -148.01821292114258
Step 1100: Generator loss: -33.54377632856369, critic loss: -128.10571365356446
Step 1150: Generator loss: 6.24583690404892, critic loss: 41.862205413818366
Step 1200: Generator loss: -31.200696147680283, critic loss: -16.397828117370594
Step 1250: Generator loss: -26.244392824172973, critic loss: -85.41039065551759
Step 1300: Generator loss: -42.6310830783844, critic loss: -58.39235194396974
Step 1350: Generator loss: 22.27627010345459, critic loss: 58.508746185302755
Step 1400: Generator loss: 22.241844139099122, critic loss: 82.20128218078612

Step 1450: Generator loss: 20.323241462707518, critic loss: 30.57237985992431
Step 1500: Generator loss: 19.14805164337158, critic loss: 16.62121188354492
Step 1550: Generator loss: 20.842887496948244, critic loss: 24.441767265319832
Step 1600: Generator loss: 21.71275493621826, critic loss: 31.92951148986816
Step 1650: Generator loss: 21.993818740844727, critic loss: 28.329695838928227
Step 1700: Generator loss: 21.36519245147705, critic loss: 21.684253944396975
Step 1750: Generator loss: 16.061803646087647, critic loss: 16.976108528137203
Step 1800: Generator loss: 12.668681163787841, critic loss: 8.760622522354126
Step 1850: Generator loss: 11.7738698387146, critic loss: 5.508992042541505

Step 1900: Generator loss: 11.233222122192382, critic loss: 4.613992996215821
Step 1950: Generator loss: 9.827853403091432, critic loss: 2.3955820655822744
Step 2000: Generator loss: 7.088875589370727, critic loss: -1.6258588695526124
Step 2050: Generator loss: 4.513366765975952, critic loss: -9.096851562499998
Step 2100: Generator loss: 5.586472408175468, critic loss: -15.782199195861818
Step 2150: Generator loss: 4.169536139369011, critic loss: -20.749746749877932
Step 2200: Generator loss: 4.396659771353006, critic loss: -25.34181425476075
Step 2250: Generator loss: 4.740511154532433, critic loss: -26.930474826812752
Step 2300: Generator loss: 3.128926914334297, critic loss: -29.72227243041992

Step 2350: Generator loss: 2.0598813822865485, critic loss: -28.949868324279777
Step 2400: Generator loss: 2.326464270055294, critic loss: -34.715240131378174
Step 2450: Generator loss: -0.25378808498382566, critic loss: -38.17930244445801
Step 2500: Generator loss: -0.06834467113018036, critic loss: -34.64415988349914
Step 2550: Generator loss: 1.8529850709438325, critic loss: -1.221760829925537
Step 2600: Generator loss: -0.4838709855079651, critic loss: -5.479033050537109
Step 2650: Generator loss: -11.033549275398254, critic loss: -25.38684425735474
Step 2700: Generator loss: -5.381214776039124, critic loss: -27.842150806427004
Step 2750: Generator loss: -3.8478834027051927, critic loss: -32.2803116350174
Step 2800: Generator loss: 0.35291880428791045, critic loss: -1.3221866350173943

Step 2850: Generator loss: -5.173080644756555, critic loss: -18.545050124168398
Step 2900: Generator loss: -8.65681962966919, critic loss: -34.319975078582765
Step 2950: Generator loss: -5.557834906578064, critic loss: -35.752119188308725
Step 3000: Generator loss: -5.460672886967659, critic loss: -31.965695253372193
Step 3050: Generator loss: -0.7373168742656708, critic loss: -23.608568210601817
Step 3100: Generator loss: -8.439535624980927, critic loss: -40.217682380676266
Step 3150: Generator loss: -5.403269990682602, critic loss: -32.19317657375335
Step 3200: Generator loss: -4.637202978730202, critic loss: -38.708125661849984
Step 3250: Generator loss: -8.382855553627014, critic loss: -38.48140592193604

Step 3300: Generator loss: -7.543863979578018, critic loss: -37.57716330718993
Step 3350: Generator loss: -5.832048915922642, critic loss: -28.588103427886963
Step 3400: Generator loss: 2.8242369401454925, critic loss: 2.1781655473709094
Step 3450: Generator loss: 0.5923665107786655, critic loss: -2.004360100269318
Step 3500: Generator loss: -3.3256619572639465, critic loss: -1.642112728118896
Step 3550: Generator loss: -5.595781497955322, critic loss: -1.928438231945038
Step 3600: Generator loss: -6.504997472763062, critic loss: -2.3893724865913395
Step 3650: Generator loss: -6.714081335067749, critic loss: -2.8723923883438114
Step 3700: Generator loss: -6.789855241775513, critic loss: -3.6496537094116217
Step 3750: Generator loss: -6.174319705963135, critic loss: -5.485189699172974

Step 3800: Generator loss: -6.597644319534302, critic loss: -9.470111770629885
Step 3850: Generator loss: -7.711389700770378, critic loss: -16.416760327339173
Step 3900: Generator loss: -3.090485827922821, critic loss: -11.075554036140442
Step 3950: Generator loss: -10.817276525497437, critic loss: -25.03645037651063
Step 4000: Generator loss: -9.933663418516517, critic loss: -24.93167032527924
Step 4050: Generator loss: -9.60056332230568, critic loss: -25.561650232315074
Step 4100: Generator loss: -1.7821738021075726, critic loss: -1.9863791446685792
Step 4150: Generator loss: -6.035529971122742, critic loss: -5.458259954452514
Step 4200: Generator loss: -17.014216923713683, critic loss: -25.492203227043156

Step 4250: Generator loss: -11.383503601551055, critic loss: -24.473569443702686
Step 4300: Generator loss: -15.300645132064819, critic loss: -28.337488735198963
Step 4350: Generator loss: -17.43895546913147, critic loss: -27.801260742187495
Step 4400: Generator loss: -8.768499946594238, critic loss: -9.016045122146608
Step 4450: Generator loss: -12.64761972784996, critic loss: -15.587988515853885
Step 4500: Generator loss: -14.287293720245362, critic loss: -25.75146458435058
Step 4550: Generator loss: -16.019245283603667, critic loss: -29.79406247901916
Step 4600: Generator loss: -13.957903513908386, critic loss: -20.346154007911686
Step 4650: Generator loss: -14.506572031974793, critic loss: -26.462586246490478

Step 4700: Generator loss: -15.950510745048524, critic loss: -26.852716032028205
Step 4750: Generator loss: -5.8097316670417785, critic loss: -4.8752821702957165
Step 4800: Generator loss: -7.61875652551651, critic loss: -10.430500278472902
Step 4850: Generator loss: -15.635288712978364, critic loss: -25.552565708160405
Step 4900: Generator loss: -18.178771851062773, critic loss: -22.525639303207395
Step 4950: Generator loss: -11.000518074035645, critic loss: -11.918203362464906
Step 5000: Generator loss: -7.161812973022461, critic loss: -3.6630326747894277
Step 5050: Generator loss: -14.269008483886719, critic loss: -9.641291046142577
Step 5100: Generator loss: -15.17794798374176, critic loss: -1.6865589847564675
Step 5150: Generator loss: -6.9874586009979245, critic loss: -2.763552607536316

Step 5200: Generator loss: -10.342125186920166, critic loss: -2.6698714928627014
Step 5250: Generator loss: -12.982628784179688, critic loss: -2.965361077785492
Step 5300: Generator loss: -13.923907432556152, critic loss: -3.394012515068054
Step 5350: Generator loss: -14.91234230041504, critic loss: -4.472306784629823
Step 5400: Generator loss: -22.589904861450194, critic loss: -20.283349450111388
Step 5450: Generator loss: -18.26389296531677, critic loss: -16.131568618774413
Step 5500: Generator loss: -18.91062386035919, critic loss: -21.24103550720215
Step 5550: Generator loss: -19.348641214966776, critic loss: -15.003384853839876
Step 5600: Generator loss: -7.568294858932495, critic loss: -3.2691667375564584

Step 5650: Generator loss: -11.761536922454834, critic loss: -3.9016947326660163
Step 5700: Generator loss: -17.730436191558837, critic loss: -10.575335399627688
Step 5750: Generator loss: -15.389623624682427, critic loss: -9.952241210937494
Step 5800: Generator loss: -8.180975332260132, critic loss: -3.626739371299743
Step 5850: Generator loss: -13.673976650238037, critic loss: -4.6690915994644175
Step 5900: Generator loss: -23.423332042694092, critic loss: -19.006739682197566
Step 5950: Generator loss: -20.846789140701294, critic loss: -23.722263383865357
Step 6000: Generator loss: -14.352950210571288, critic loss: -11.94288802623749
Step 6050: Generator loss: -17.197109088897705, critic loss: -15.024040604591368

Step 6100: Generator loss: -17.678665273189544, critic loss: -19.678323865890505
Step 6150: Generator loss: -5.51946144580841, critic loss: 1.422172403335569
Step 6200: Generator loss: -5.6169371414184575, critic loss: -3.0744575681686404
Step 6250: Generator loss: -9.507413730621337, critic loss: -3.4173239688873287
Step 6300: Generator loss: -13.570071792602539, critic loss: -6.6362248277664175
Step 6350: Generator loss: -9.134255447387694, critic loss: 0.8062635078430179
Step 6400: Generator loss: -7.847699356079102, critic loss: -2.8099637699127196
Step 6450: Generator loss: -10.073278160095215, critic loss: -3.3017093381881715
Step 6500: Generator loss: -10.9572945022583, critic loss: -3.9077557072639455
Step 6550: Generator loss: -16.45351732969284, critic loss: -14.321979390621182

Step 6600: Generator loss: -12.738411252498627, critic loss: -0.09198231744766405
Step 6650: Generator loss: -3.051986961364746, critic loss: -2.744908696651459
Step 6700: Generator loss: -6.664891853332519, critic loss: -2.5603360729217526
Step 6750: Generator loss: -7.8842816734313965, critic loss: -2.6821388635635377
Step 6800: Generator loss: -7.849865465164185, critic loss: -2.8346559019088744
Step 6850: Generator loss: -8.349686765670777, critic loss: -3.0016350402832033
Step 6900: Generator loss: -9.607725353240967, critic loss: -3.1205613212585446
Step 6950: Generator loss: -10.71149227142334, critic loss: -3.2234377617836016
Step 7000: Generator loss: -10.954191932678222, critic loss: -4.686962251663208

Step 7050: Generator loss: -8.436050066947937, critic loss: -2.239833469867707
Step 7100: Generator loss: -10.932452278137207, critic loss: -3.895902329444885
Step 7150: Generator loss: -10.667149393558502, critic loss: -6.043405599594115
Step 7200: Generator loss: -11.712564096450805, critic loss: -8.265575465202332
Step 7250: Generator loss: -8.328836711347103, critic loss: -6.279276549816129
Step 7300: Generator loss: -6.195756266117096, critic loss: -5.744718703269959
Step 7350: Generator loss: -15.353437192440033, critic loss: -16.115581097602846
Step 7400: Generator loss: -15.742991452217103, critic loss: -18.675537931442264
Step 7450: Generator loss: -8.320359216928482, critic loss: -13.353263663291933
Step 7500: Generator loss: -11.154643761515617, critic loss: -17.828186235427857

Step 7550: Generator loss: -3.1160362350940702, critic loss: -5.032801697731019
Step 7600: Generator loss: -4.732586090564728, critic loss: -9.268116512298585
Step 7650: Generator loss: -3.844493716955185, critic loss: -7.1622213621139545
Step 7700: Generator loss: -6.2145571875572205, critic loss: -7.966237799644469
Step 7750: Generator loss: -6.341645574271679, critic loss: -11.959992330551152
Step 7800: Generator loss: -9.418573129177094, critic loss: -15.580542749404904
Step 7850: Generator loss: -4.848915230631828, critic loss: -14.312010268211363
Step 7900: Generator loss: -7.77654335141182, critic loss: -15.854042816162107
Step 7950: Generator loss: -8.576519481539727, critic loss: -11.021303612709044

Step 8000: Generator loss: -4.719242013692856, critic loss: -11.11103640270233
Step 8050: Generator loss: -6.0516835153102875, critic loss: -12.72197005558014
Step 8100: Generator loss: -3.2008099579811096, critic loss: -14.758851803779601
Step 8150: Generator loss: 0.3244405049085617, critic loss: -12.399096450805661
Step 8200: Generator loss: -2.8401265180110933, critic loss: -15.013349224090575
Step 8250: Generator loss: 4.442987043857574, critic loss: -3.980695918083192
Step 8300: Generator loss: 5.22506322324276, critic loss: -4.466050217628479
Step 8350: Generator loss: -5.91634309887886, critic loss: -14.373947637557979
Step 8400: Generator loss: -3.543821589946747, critic loss: -15.954828742980954

Step 8450: Generator loss: 0.9829114472866058, critic loss: -15.809339665412903
Step 8500: Generator loss: 1.6671456277370453, critic loss: -5.812027407169342
Step 8550: Generator loss: 3.7685218435525893, critic loss: -12.108838273048402
Step 8600: Generator loss: 0.8718191981315613, critic loss: -14.347978047132488
Step 8650: Generator loss: 2.8619570606946945, critic loss: -12.882514192581176
Step 8700: Generator loss: 5.4293956917524335, critic loss: -12.452005095720292
Step 8750: Generator loss: -2.1058886897563935, critic loss: -10.724813097476959
Step 8800: Generator loss: -0.32922265529632566, critic loss: -12.968225153923036
Step 8850: Generator loss: 3.0781015133857728, critic loss: -14.396550803184512
Step 8900: Generator loss: 5.7408320564031605, critic loss: -14.785713949203489

Step 8950: Generator loss: 12.829762142598629, critic loss: -12.201240473747259
Step 9000: Generator loss: 6.454181790351868, critic loss: -8.54612139034271
Step 9050: Generator loss: 6.633567657470703, critic loss: -3.913219250679014
Step 9100: Generator loss: 0.3205719903111458, critic loss: -14.025861147880553
Step 9150: Generator loss: 9.336125988066197, critic loss: -11.034634428024294
Step 9200: Generator loss: 3.5905175483226777, critic loss: -16.49209539937973
Step 9250: Generator loss: 10.085442260503768, critic loss: -13.059457175731659
Step 9300: Generator loss: 9.493704957962036, critic loss: -15.21216254377365
Step 9350: Generator loss: 7.436434207558632, critic loss: -15.097099609851844

Step 9400: Generator loss: 12.055890271663666, critic loss: -3.772059518337248
Step 9450: Generator loss: 12.08195770263672, critic loss: 2.123932291030884
Step 9500: Generator loss: 13.062031269073486, critic loss: -0.31967130613327027
Step 9550: Generator loss: 13.092384490966797, critic loss: -2.279303745746612
Step 9600: Generator loss: 14.222563571929932, critic loss: -3.5600462260246273
Step 9650: Generator loss: 11.661315126419067, critic loss: -6.065145170688631
Step 9700: Generator loss: 12.200704522132874, critic loss: -7.031322331905365
Step 9750: Generator loss: 10.650458063483239, critic loss: -9.190834596633911
Step 9800: Generator loss: 12.14132619380951, critic loss: -9.531413751125335

Step 9850: Generator loss: 13.63720681667328, critic loss: -10.47444561767578
Step 9900: Generator loss: 17.503834223747255, critic loss: -5.384989888429642
Step 9950: Generator loss: 13.614862778186797, critic loss: -9.991226898193363
Step 10000: Generator loss: 7.091982510089874, critic loss: -14.685286004066475
Step 10050: Generator loss: 15.369895470142364, critic loss: -13.914708140850063
Step 10100: Generator loss: 13.818947200775147, critic loss: -5.275847900152208
Step 10150: Generator loss: 16.769601423740387, critic loss: -7.283954842090603
Step 10200: Generator loss: 13.539169175624847, critic loss: -7.575226979255677
Step 10250: Generator loss: 12.903843214511872, critic loss: -13.834916482448582
Step 10300: Generator loss: 13.824523911178112, critic loss: -11.565740355968474

Step 10350: Generator loss: 23.126209487915037, critic loss: -3.5065316848754877
Step 10400: Generator loss: 11.144089060425758, critic loss: -13.804089553833006
Step 10450: Generator loss: 16.73342710494995, critic loss: -16.60548991680145
Step 10500: Generator loss: 14.49173572719097, critic loss: -16.303591802120206
Step 10550: Generator loss: 15.281622717380523, critic loss: -16.646503752708433
Step 10600: Generator loss: 19.046842958927154, critic loss: -13.268205133438116
Step 10650: Generator loss: 24.31525261491537, critic loss: -12.346739721775055
Step 10700: Generator loss: 16.24910553455353, critic loss: -16.35200205373764
Step 10750: Generator loss: 19.570902199745177, critic loss: -15.598573707580567

Step 10800: Generator loss: 15.220140396356582, critic loss: -16.457515287160874
Step 10850: Generator loss: 20.54370219707489, critic loss: -5.76649915742874
Step 10900: Generator loss: 25.237007217407225, critic loss: -4.718336683273314
Step 10950: Generator loss: 17.643479496240616, critic loss: -14.813091954231261
Step 11000: Generator loss: 17.931723778247832, critic loss: -16.6928783416748
Step 11050: Generator loss: 15.076536083221436, critic loss: -17.14448805379868
Step 11100: Generator loss: 21.38611768901348, critic loss: -14.199283111572264
Step 11150: Generator loss: 15.369360786676406, critic loss: -15.566518260478972
Step 11200: Generator loss: 17.964607070684433, critic loss: -19.536781085014344
Step 11250: Generator loss: 19.02104866027832, critic loss: -17.502904921531677

Step 11300: Generator loss: 16.893439546525478, critic loss: -21.1421114025116
Step 11350: Generator loss: 17.84908836364746, critic loss: -14.93827541065216
Step 11400: Generator loss: 15.798309868872165, critic loss: -21.581478227615364
Step 11450: Generator loss: 23.5217720413208, critic loss: -14.291318614482885
Step 11500: Generator loss: 24.097826385498045, critic loss: -13.374350033283237
Step 11550: Generator loss: 16.69866649389267, critic loss: -8.168238547325137
Step 11600: Generator loss: 18.20891289591789, critic loss: -15.527288549184803
Step 11650: Generator loss: 25.441228060722352, critic loss: -12.117457561016082
Step 11700: Generator loss: 23.09286418914795, critic loss: -16.468590665817263

Step 11750: Generator loss: 19.499136786460877, critic loss: -16.775978488445283
Step 11800: Generator loss: 26.939843635559082, critic loss: -13.736246906757351
Step 11850: Generator loss: 19.867035599946977, critic loss: -7.34710036563873
Step 11900: Generator loss: 26.173578137159346, critic loss: -10.837821035385131
Step 11950: Generator loss: 19.16404739379883, critic loss: -19.601039904117584
Step 12000: Generator loss: 25.78880739212036, critic loss: -11.808945682525634
Step 12050: Generator loss: 25.231854648590087, critic loss: -13.578082417488101
Step 12100: Generator loss: 21.96565896987915, critic loss: -16.173608081340788
Step 12150: Generator loss: 20.569176206588747, critic loss: -17.96209384202957

Step 12200: Generator loss: 21.31651187181473, critic loss: -17.643821037769317
Step 12250: Generator loss: 23.543673152923585, critic loss: -6.420106256246568
Step 12300: Generator loss: 28.98377960205078, critic loss: -5.6026851439476015
Step 12350: Generator loss: 23.064338397979736, critic loss: -13.588204405784603
Step 12400: Generator loss: 14.360887625217437, critic loss: -19.073876594066626
Step 12450: Generator loss: 22.12294903278351, critic loss: -19.880444211006164
Step 12500: Generator loss: 19.939283509254455, critic loss: -20.863597845554345
Step 12550: Generator loss: 23.750905470848082, critic loss: -17.536349697828292
Step 12600: Generator loss: 31.971336555480956, critic loss: -5.809852944374083
Step 12650: Generator loss: 13.045195934772492, critic loss: -19.511251863241192

Step 12700: Generator loss: 18.092088370323182, critic loss: -18.151690407276153
Step 12750: Generator loss: 31.612741775512696, critic loss: 2.7522802910804747
Step 12800: Generator loss: 20.272375016212465, critic loss: -14.69069622564316
Step 12850: Generator loss: 17.574717574119568, critic loss: -19.360476994991295
Step 12900: Generator loss: 25.326936082839964, critic loss: -15.019393275976183
Step 12950: Generator loss: 22.579223737716674, critic loss: -15.558635712623595
Step 13000: Generator loss: 22.331913833618163, critic loss: -18.480101360797878
Step 13050: Generator loss: 24.743406887054444, critic loss: -16.38874498939514
Step 13100: Generator loss: 23.625229470729828, critic loss: -17.396287363052362

Step 13150: Generator loss: 27.520250186920165, critic loss: -12.830729541778561
Step 13200: Generator loss: 27.969804508686067, critic loss: -13.88786848115921
Step 13250: Generator loss: 21.9987981569767, critic loss: -16.44161170268059
Step 13300: Generator loss: 24.94612365245819, critic loss: -19.222108069896702
Step 13350: Generator loss: 21.622284948825836, critic loss: -16.349706367731095
Step 13400: Generator loss: 17.488346133232117, critic loss: -14.18832929229737
Step 13450: Generator loss: 18.58687307357788, critic loss: -22.756161051750176
Step 13500: Generator loss: 28.092126989364623, critic loss: -15.146699063301087
Step 13550: Generator loss: 17.88306753873825, critic loss: -19.877120500564573
Step 13600: Generator loss: 27.77892411708832, critic loss: -14.512011850357055

Step 13650: Generator loss: 24.701718806028367, critic loss: -6.653988728523253
Step 13700: Generator loss: 24.075694761276246, critic loss: -12.945459551334379
Step 13750: Generator loss: 14.400666347444057, critic loss: -23.600848259925844
Step 13800: Generator loss: 27.03461096525192, critic loss: -14.083832769870758
Step 13850: Generator loss: 24.518703536987303, critic loss: -17.960850058317185
Step 13900: Generator loss: 25.687818822860716, critic loss: -4.2380144739151
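The raw print-out above is hard to read as a trend. As a side note, a small helper (a sketch added here, not part of the original notebook) can parse lines in this exact format back into numeric series, which you could then smooth and plot with `matplotlib` alongside the image grids:

```python
import re

# Hypothetical helper: recover (step, generator loss, critic loss) triples
# from log lines shaped like the training output above.
LOG_PATTERN = re.compile(
    r"Step (\d+): Generator loss: (-?[\d.]+), critic loss: (-?[\d.]+)"
)

def parse_log(lines):
    steps, gen_losses, crit_losses = [], [], []
    for line in lines:
        match = LOG_PATTERN.search(line)
        if match:
            steps.append(int(match.group(1)))
            gen_losses.append(float(match.group(2)))
            crit_losses.append(float(match.group(3)))
    return steps, gen_losses, crit_losses

# Two lines copied from the output above, as a usage example.
sample = [
    "Step 5450: Generator loss: -18.26389296531677, critic loss: -16.131568618774413",
    "Step 5500: Generator loss: -18.91062386035919, critic loss: -21.24103550720215",
]
steps, gen_losses, crit_losses = parse_log(sample)
print(steps)  # → [5450, 5500]
```

Note that, unlike the BCE losses you saw earlier, W-loss values are unbounded real numbers, so negative critic losses in the log are expected rather than a sign of a bug.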